
from PIL import Image
import numpy as np
import torch

from diffusers import StableDiffusionPipeline, DDIMScheduler

# Run on the GPU when one is available; otherwise fall back to the CPU so the
# pipeline can still be loaded and run (slowly) on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model(model_id: str = "runwayml/stable-diffusion-v1-5"):
    """
    Load a Stable Diffusion pipeline equipped with a DDIM scheduler.

    Parameters
    -----------------------------
     - model_id : str (default is "runwayml/stable-diffusion-v1-5")
        Hugging Face Hub repository id (or local directory) of the SD
        checkpoint passed to `StableDiffusionPipeline.from_pretrained`.
        NOTE(review): the default `runwayml/...` repository may no longer be
        hosted on the Hub; pass a mirror or a local path if loading fails.

    Returns
    -----------------------------
     - model : StableDiffusionPipeline
        Loaded SD pipeline, moved to the module-level `device`
    """
    # Deterministic DDIM scheduler with a scaled-linear beta schedule.
    # These betas / flags are presumably chosen to match the SD v1 training
    # configuration; set_alpha_to_one=False determines the alpha used for the
    # step before the first timestep (relevant to DDIM inversion).
    scheduler = DDIMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
    )
    model = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler).to(device)
    return model


def load_resize_image(path: str, size: int = 512):
    """
    Load an image from the given path, center-crop it to a square, and resize it.

    Parameters
    -----------------------------
     - path : str
        The path of the image to load
     - size : int or None (default is 512)
        The side length of the resized (square) output image.
        If None, the center-cropped image is returned without resizing.

    Returns
    -----------------------------
     - image : PIL.Image.Image
        The loaded, center-cropped and resized RGB image
    """
    # Decode and convert inside the context manager so the underlying file
    # handle is closed on return. Converting to RGB guarantees the
    # (H, W, 3) layout that `negative_prompt_inversion` relies on; grayscale,
    # RGBA or palette images would otherwise break its permute / VAE encode.
    with Image.open(path) as raw:
        image = raw.convert("RGB")
    W, H = image.size

    # When the image is not square, center-crop it to S x S (S = shorter side).
    if W != H:
        S = min(W, H)
        left = (W - S) // 2
        top = (H - S) // 2
        image = image.crop((left, top, left + S, top + S))

    if size is not None:
        image = image.resize((size, size))
    return image


@torch.no_grad()
def negative_prompt_inversion(model, image_pil: Image.Image, prompt: str, num_ddim_steps: int = 50):
    """
    Negative-prompt inversion.

    Computes $z_T$ by DDIM inversion of the given image, conditioned on the
    prompt embedding $C$, and also returns $C$ so it can be reused in place of
    the null-text (unconditional) embedding when sampling with CFG.
    This implementation is mainly based on the implementation in 'diffusers'.

    Parameters
    -----------------------------
     - model
        SD model (a diffusers StableDiffusionPipeline with a DDIM scheduler)
     - image_pil : PIL.Image.Image
        The image to reconstruct. Its pixel values are mapped from [0, 255]
        to [-1, 1] before VAE encoding. Assumes an RGB image of shape
        (H, W, 3) — the channel permute requires a trailing channel axis;
        presumably 512x512 as produced by `load_resize_image` (TODO confirm).
     - prompt : str
        The prompt corresponding to the given image
     - num_ddim_steps : int (default is 50)
        The number of DDIM inversion steps

    Returns
    -----------------------------
     - latent : torch.Tensor
        The tensor corresponding to $z_T$; (1, 4, 64, 64) for a 512x512 input
        (the spatial size presumably scales with the input size).
     - cond_embed : torch.Tensor
        The tensor corresponding to $C$; presumably (1, 77, 768) for SD v1.x.
    """
    # Only the conditional (prompt) embedding is requested: CFG is disabled and
    # no negative prompt is given.
    cond_embed = model._encode_prompt(prompt, device, num_images_per_prompt=1, do_classifier_free_guidance=False, negative_prompt=None)

    # Calculate $z_0$ from the given image.
    # (H, W, C) uint8 -> float in [-1, 1] -> (1, C, H, W) on `device`.
    image = torch.from_numpy(np.array(image_pil) / 127.5 - 1).float().permute(2, 0, 1).unsqueeze(0).to(device)
    # Use the posterior mean (no sampling) for a deterministic latent; 0.18215 is
    # the SD v1 latent scaling factor (presumably equal to
    # model.vae.config.scaling_factor — confirm for other checkpoints).
    latent = 0.18215 * model.vae.encode(image)["latent_dist"].mean

    model.scheduler.set_timesteps(num_ddim_steps)
    T = model.scheduler.config.num_train_timesteps

    # DDIM inversion: walk the sampling timesteps in reverse (least to most
    # noisy), moving the latent from timestep t to t_next at each step.
    for i in range(num_ddim_steps):
        # scheduler.timesteps is in sampling order, so index it from the end.
        t_next = model.scheduler.timesteps[-1-i]
        # Previous (less noisy) timestep; assumes T // num_ddim_steps equals the
        # scheduler's spacing between consecutive timesteps.
        t = t_next - T // num_ddim_steps

        # On the first step t can be negative; the scheduler's
        # final_alpha_cumprod (governed by its set_alpha_to_one flag) is used then.
        alpha_prod_t = model.scheduler.alphas_cumprod[t] if t >= 0 else model.scheduler.final_alpha_cumprod
        alpha_prod_t_next = model.scheduler.alphas_cumprod[t_next]
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_next = 1 - alpha_prod_t_next

        # The noise is predicted from the current latent (at t) but with the
        # timestep t_next — the usual DDIM-inversion approximation, since the
        # latent at t_next is not known yet.
        model_output = model.unet(latent, t_next, encoder_hidden_states=cond_embed).sample

        # Predict x_0 from the latent at t, then deterministically re-noise it
        # to t_next with the same noise estimate (DDIM update, eta = 0).
        original_pred = (latent - (beta_prod_t**0.5) * model_output) / (alpha_prod_t**0.5)
        latent = (alpha_prod_t_next**0.5) * original_pred + (beta_prod_t_next**0.5) * model_output

    return latent, cond_embed


@torch.no_grad()
def generate(
    model, prompt: str, uncond_embed, latent: torch.Tensor,
    num_ddim_steps: int = 50, guidance_scale: float = 7.5,
):
    """
    Calculate $z_0$ by DDIM sampling with classifier-free guidance (CFG) from
    $z_T$, and decode it to an image.
    This implementation is mainly based on the implementation in 'diffusers'.

    Parameters
    -----------------------------
     - model
        SD model (a diffusers StableDiffusionPipeline with a DDIM scheduler)
     - prompt : str
        The prompt used for the conditional branch of CFG
     - uncond_embed : list[torch.Tensor] or torch.Tensor
        The embedding for the unconditional branch of CFG.
        If it is a list (or tuple), ``uncond_embed[i]`` is used at the i-th
        sampling step, so it must have at least one entry per step.
     - latent : torch.Tensor
        The tensor corresponding to $z_T$ (a single latent, batch size 1).
     - num_ddim_steps : int (default is 50)
        The number of sampling steps
     - guidance_scale : float (default is 7.5)
        The value of the CFG parameter $w$

    Returns
    -----------------------------
     - image : PIL.Image.Image
        The generated image

    Raises
    -----------------------------
     - ValueError
        If ``uncond_embed`` is a list/tuple with fewer entries than the number
        of sampling steps.
    """
    prompt_embed = model._encode_prompt(prompt, device, num_images_per_prompt=1, do_classifier_free_guidance=False, negative_prompt=None)

    model.scheduler.set_timesteps(num_ddim_steps)
    timesteps = model.scheduler.timesteps

    per_step_uncond = isinstance(uncond_embed, (list, tuple))
    # Fail fast instead of raising IndexError midway through sampling, after
    # many expensive UNet evaluations have already been spent.
    if per_step_uncond and len(uncond_embed) < len(timesteps):
        raise ValueError(
            f"uncond_embed has {len(uncond_embed)} entries but {len(timesteps)} "
            "sampling steps are run; one embedding per step is required."
        )

    for i, t in enumerate(timesteps):
        uncond_embed_ = uncond_embed[i] if per_step_uncond else uncond_embed

        # Evaluate the unconditional and conditional branches in one batched
        # UNet call: first half uses uncond_embed_, second half prompt_embed.
        latent_model_input = torch.cat([latent] * 2)
        noise_pred = model.unet(
            latent_model_input,
            t,
            encoder_hidden_states=torch.cat([uncond_embed_, prompt_embed]),
        ).sample

        # Classifier-free Guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latent = model.scheduler.step(noise_pred, t, latent).prev_sample

    image = model.decode_latents(latent)
    # The batch holds a single image.
    image = model.numpy_to_pil(image)[0]
    return image
